#include "../include/CentralCache.h"

#include "../include/PageCache.h"

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>

namespace mp {

static const size_t SPAN_PAGES = 8;  // span size fetched from PageCache each time (in pages)

// Fetches up to batchNum free blocks of size class idx for a ThreadCache.
// Returns the head of a singly linked chain of blocks (next pointer stored in
// the first bytes of each block), or nullptr on failure / oversized request.
// Index check: idx >= FREE_LIST_SIZE means the allocation is too large and
// should be requested from the system directly.
void *CentralCache::fetchRange(
    size_t idx,
    size_t batchNum) {
    if (idx >= FREE_LIST_SIZE || batchNum == 0)
        return nullptr;

    // Spin-lock protecting this size class; yield while waiting so we do not
    // burn CPU in a tight busy-wait.
    while (locks_[idx].test_and_set(std::memory_order_acquire)) {
        std::this_thread::yield();
    }

    void *result = nullptr;
    try {
        // Try to serve the request from the central free list first.
        result = centralFreeList_[idx].load(std::memory_order_relaxed);

        if (!result) {
            // Central list is empty: refill with a fresh span from PageCache.
            size_t size = (idx + 1) * ALIGNMENT;
            result = fetchFromPageCache(size);

            if (!result) {
                locks_[idx].clear(std::memory_order_release);
                return nullptr;
            }

            // Carve the span into size-class blocks.
            char *start = static_cast<char *>(result);
            size_t totalBlocks = (SPAN_PAGES * PageCache::PAGE_SIZE) / size;
            size_t allocBlocks = std::min(batchNum, totalBlocks);

            // Build the chain handed back to ThreadCache.
            // BUG FIX: the nullptr terminator must be written even when
            // allocBlocks == 1 — span memory is not guaranteed to be zeroed
            // (recycled spans contain stale link pointers), so a lone block
            // would otherwise carry a garbage next pointer. The > 0 guard also
            // avoids the (allocBlocks - 1) underflow when totalBlocks == 0
            // (size larger than a span: the span is returned uncut, as before).
            if (allocBlocks > 0) {
                for (size_t i = 1; i < allocBlocks; ++i) {
                    void *current = start + (i - 1) * size;
                    void *next = start + i * size;
                    *reinterpret_cast<void **>(current) = next;
                }
                *reinterpret_cast<void **>(start + (allocBlocks - 1) * size) = nullptr;
            }

            // Chain the leftover blocks and keep them in CentralCache.
            if (totalBlocks > allocBlocks) {
                void *remainStart = start + allocBlocks * size;
                for (size_t i = allocBlocks + 1; i < totalBlocks; ++i) {
                    void *current = start + (i - 1) * size;
                    void *next = start + i * size;
                    *reinterpret_cast<void **>(current) = next;
                }
                *reinterpret_cast<void **>(start + (totalBlocks - 1) * size) = nullptr;

                centralFreeList_[idx].store(remainStart, std::memory_order_release);
            }
        } else {
            // Central list has blocks of this size class: detach up to
            // batchNum of them from the head.
            void *current = result;
            void *prev = nullptr;
            size_t count = 0;

            while (current && count < batchNum) {
                prev = current;
                current = *reinterpret_cast<void **>(current);
                count++;
            }

            // Terminate the detached chain (needed when the list held more
            // than batchNum blocks).
            if (prev) {
                *reinterpret_cast<void **>(prev) = nullptr;
            }

            // Whatever follows the detached chain becomes the new list head.
            centralFreeList_[idx].store(current, std::memory_order_release);
        }
    } catch (...) {
        // Never leave the spin-lock held on an exception path.
        locks_[idx].clear(std::memory_order_release);
        throw;
    }

    // Release the lock on the normal path.
    locks_[idx].clear(std::memory_order_release);
    return result;
}

// Returns a chain of blocks to the central free list for size class idx.
// When idx >= FREE_LIST_SIZE the memory is oversized and is handed back to
// the system elsewhere, so this is a no-op.
void CentralCache::returnRange(void *start, size_t size, size_t idx) {
    if (!start || idx >= FREE_LIST_SIZE)
        return;

    // Acquire the per-size-class spin-lock, yielding while contended.
    while (locks_[idx].test_and_set(std::memory_order_acquire)) {
        std::this_thread::yield();
    }

    try {
        // Walk to the tail of the returned chain, visiting at most `size`
        // nodes. NOTE(review): `size` looks like the block count of the
        // chain rather than a byte size — confirm against the callers.
        void *tail = start;
        size_t walked = 1;
        for (void *next = *reinterpret_cast<void **>(tail);
             next != nullptr && walked < size;
             next = *reinterpret_cast<void **>(tail)) {
            tail = next;
            ++walked;
        }

        // Splice the returned chain onto the front of the central list:
        // tail -> old head, then publish the new head.
        void *oldHead = centralFreeList_[idx].load(std::memory_order_relaxed);
        *reinterpret_cast<void **>(tail) = oldHead;
        centralFreeList_[idx].store(start, std::memory_order_release);
    } catch (...) {
        // Release the lock before propagating.
        locks_[idx].clear(std::memory_order_release);
        throw;
    }

    locks_[idx].clear(std::memory_order_release);
}

// Obtains raw memory from the PageCache for a request of `size` bytes.
// Requests that fit inside one fixed span always take SPAN_PAGES pages so the
// span can be carved into many blocks; larger requests get exactly the number
// of pages they need, rounded up.
void *CentralCache::fetchFromPageCache(size_t size) {
    const size_t spanBytes = SPAN_PAGES * PageCache::PAGE_SIZE;

    if (size <= spanBytes) {
        // Small/medium request: grab one full fixed-size span.
        return PageCache::getInstance().allocateSpan(SPAN_PAGES);
    }

    // Oversized request: allocate just enough whole pages (ceiling division).
    const size_t pagesNeeded =
        (size + PageCache::PAGE_SIZE - 1) / PageCache::PAGE_SIZE;
    return PageCache::getInstance().allocateSpan(pagesNeeded);
}

}  // namespace mp